use std::fmt;
use std::io::{Read, Write};
use std::sync::{Arc, Mutex};

use polars_utils::arena::Node;
#[cfg(feature = "serde")]
use polars_utils::pl_serialize;
use polars_utils::unique_id::UniqueId;
use recursive::recursive;
#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};

use super::*;

/// DSL wire-format version as `(major, minor)`.
///
/// Bumping this is normally unnecessary: compatibility of serialized plans is checked through the
/// schema hash (`DSL_SCHEMA_HASH`). Only bump it for a breaking change that does not alter the
/// schema hash. On deserialization the major version must match exactly and the minor version
/// must not be newer than this one.
pub const DSL_VERSION: (u16, u16) = (24, 0);
// Magic prefix written at the very start of every versioned DSL payload, before the version.
const DSL_MAGIC_BYTES: &[u8] = b"DSL_VERSION";

// Hex-encoded hash of the DSL schema this build was compiled against, read from the file that
// `build.rs` generates. Written after the version and compared on deserialization.
const DSL_SCHEMA_HASH: SchemaHash<'static> = SchemaHash::from_hash_file();

/// Logical query plan as built through the DSL.
/// It is lowered to an `IRPlan` by `DslPlan::to_alp` and is the value that the versioned
/// (de)serialization functions read and write.
#[derive(Debug)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "dsl-schema", derive(schemars::JsonSchema))]
pub enum DslPlan {
    /// Scan configured by `PythonOptionsDsl` (only with the `python` feature).
    #[cfg(feature = "python")]
    PythonScan {
        options: crate::dsl::python_dsl::PythonOptionsDsl,
    },
    /// Filter `input` on a boolean mask produced by `predicate`.
    Filter {
        input: Arc<DslPlan>,
        predicate: Expr,
    },
    /// Cache the input at this point in the LP; `id` identifies the cache.
    Cache {
        input: Arc<DslPlan>,
        id: UniqueId,
    },
    /// Scan over `sources`, configured by `unified_scan_args` and `scan_type`.
    Scan {
        sources: ScanSources,
        unified_scan_args: Box<UnifiedScanArgs>,
        scan_type: Box<FileScanDsl>,
        /// Local use cases often collect the same `LazyFrame` repeatedly (e.g. in interactive
        /// notebooks), so the IR conversion is cached here because path expansion can be slow,
        /// especially for cloud paths. No arena is needed because this is always a source node.
        /// Not serialized, and shared (not copied) between clones of the plan.
        #[cfg_attr(any(feature = "serde", feature = "dsl-schema"), serde(skip))]
        cached_ir: Arc<Mutex<Option<IR>>>,
    },
    /// Scan of an in-memory `DataFrame`.
    DataFrameScan {
        df: Arc<DataFrame>,
        schema: SchemaRef,
    },
    /// Polars' `select` operation, this can mean projection, but also full data access.
    Select {
        expr: Vec<Expr>,
        input: Arc<DslPlan>,
        options: ProjectionOptions,
    },
    /// Group-by aggregation.
    GroupBy {
        input: Arc<DslPlan>,
        keys: Vec<Expr>,
        predicates: Vec<Expr>,
        aggs: Vec<Expr>,
        maintain_order: bool,
        options: Arc<GroupbyOptions>,
        apply: Option<(PlanCallback<DataFrame, DataFrame>, SchemaRef)>,
    },
    /// Join of `input_left` and `input_right`.
    Join {
        input_left: Arc<DslPlan>,
        input_right: Arc<DslPlan>,
        // Invariant: left_on and right_on are equal length.
        left_on: Vec<Expr>,
        right_on: Vec<Expr>,
        // Invariant: Either left_on/right_on or predicates is set (non-empty).
        predicates: Vec<Expr>,
        options: Arc<JoinOptions>,
    },
    /// Add columns to the table without a join.
    HStack {
        input: Arc<DslPlan>,
        exprs: Vec<Expr>,
        options: ProjectionOptions,
    },
    /// Match / evolve `input` into `match_schema`.
    MatchToSchema {
        input: Arc<DslPlan>,
        /// The schema to match to.
        ///
        /// This is also always the output schema.
        match_schema: SchemaRef,

        per_column: Arc<[MatchToSchemaPerColumn]>,

        extra_columns: ExtraColumnsPolicy,
    },
    /// Plan produced by `callback`, which receives a list of plans and a list of schemas
    /// (presumably `input` and the inputs' resolved schemas; verify against the IR conversion).
    PipeWithSchema {
        input: Arc<[DslPlan]>,
        callback: PlanCallback<(Vec<DslPlan>, Vec<SchemaRef>), DslPlan>,
    },
    /// Pivot operation (only with the `pivot` feature).
    #[cfg(feature = "pivot")]
    Pivot {
        input: Arc<DslPlan>,
        on: Selector,
        on_columns: Arc<DataFrame>,
        index: Selector,
        values: Selector,
        agg: Expr,
        maintain_order: bool,
        separator: PlSmallStr,
    },
    /// Remove duplicates from the table
    Distinct {
        input: Arc<DslPlan>,
        options: DistinctOptionsDSL,
    },
    /// Sort the table
    Sort {
        input: Arc<DslPlan>,
        by_column: Vec<Expr>,
        slice: Option<(i64, usize)>,
        sort_options: SortMultipleOptions,
    },
    /// Slice the table
    Slice {
        input: Arc<DslPlan>,
        offset: i64,
        len: IdxSize,
    },
    /// A (User Defined) Function
    MapFunction {
        input: Arc<DslPlan>,
        function: DslFunction,
    },
    /// Vertical concatenation
    Union {
        inputs: Vec<DslPlan>,
        args: UnionArgs,
    },
    /// Horizontal concatenation of multiple plans
    HConcat {
        inputs: Vec<DslPlan>,
        options: HConcatOptions,
    },
    /// This allows expressions to access other tables
    ExtContext {
        input: Arc<DslPlan>,
        contexts: Vec<DslPlan>,
    },
    /// Sink `input` as described by `payload`.
    Sink {
        input: Arc<DslPlan>,
        payload: SinkType,
    },
    /// Several sink plans combined into one plan.
    SinkMultiple {
        inputs: Vec<DslPlan>,
    },
    /// Merge of two inputs on `key` (only with the `merge_sorted` feature); presumably both
    /// inputs must already be sorted on `key`.
    #[cfg(feature = "merge_sorted")]
    MergeSorted {
        input_left: Arc<DslPlan>,
        input_right: Arc<DslPlan>,
        key: PlSmallStr,
    },
    /// A plan that has already been converted to IR; `node` is never serialized.
    IR {
        // Keep the original Dsl around as we need that for serialization.
        dsl: Arc<DslPlan>,
        version: u32,
        #[cfg_attr(any(feature = "serde", feature = "dsl-schema"), serde(skip))]
        node: Option<Node>,
    },
}

impl Clone for DslPlan {
    // Written by hand instead of derived, presumably so `#[recursive]` can be applied: variants
    // such as `Union`, `HConcat`, `ExtContext` and `SinkMultiple` own `Vec<DslPlan>` children
    // that are cloned recursively. `Arc` fields are cloned with `Arc::clone` (a reference-count
    // bump), so e.g. `Scan::cached_ir` is shared between a plan and its clone, not copied.
    // Arms follow the declaration order of the enum; rustfmt is skipped to keep them uniform.
    #[rustfmt::skip]
    #[allow(clippy::clone_on_copy)]
    #[recursive]
    fn clone(&self) -> Self {
        match self {
            #[cfg(feature = "python")]
            Self::PythonScan { options } => Self::PythonScan {
                options: options.clone(),
            },
            Self::Filter { input, predicate } => Self::Filter {
                input: Arc::clone(input),
                predicate: predicate.clone(),
            },
            Self::Cache { input, id } => Self::Cache {
                input: Arc::clone(input),
                id: *id,
            },
            Self::Scan { sources, unified_scan_args, scan_type, cached_ir } => Self::Scan {
                sources: sources.clone(),
                unified_scan_args: unified_scan_args.clone(),
                scan_type: scan_type.clone(),
                cached_ir: Arc::clone(cached_ir),
            },
            Self::DataFrameScan { df, schema } => Self::DataFrameScan {
                df: Arc::clone(df),
                schema: schema.clone(),
            },
            Self::Select { expr, input, options } => Self::Select {
                expr: expr.clone(),
                input: Arc::clone(input),
                options: options.clone(),
            },
            Self::GroupBy { input, keys, predicates, aggs, maintain_order, options, apply } => Self::GroupBy {
                input: Arc::clone(input),
                keys: keys.clone(),
                predicates: predicates.clone(),
                aggs: aggs.clone(),
                maintain_order: *maintain_order,
                options: Arc::clone(options),
                apply: apply.clone(),
            },
            Self::Join { input_left, input_right, left_on, right_on, predicates, options } => Self::Join {
                input_left: Arc::clone(input_left),
                input_right: Arc::clone(input_right),
                left_on: left_on.clone(),
                right_on: right_on.clone(),
                predicates: predicates.clone(),
                options: Arc::clone(options),
            },
            Self::HStack { input, exprs, options } => Self::HStack {
                input: Arc::clone(input),
                exprs: exprs.clone(),
                options: options.clone(),
            },
            Self::MatchToSchema { input, match_schema, per_column, extra_columns } => Self::MatchToSchema {
                input: Arc::clone(input),
                match_schema: match_schema.clone(),
                per_column: Arc::clone(per_column),
                extra_columns: *extra_columns,
            },
            Self::PipeWithSchema { input, callback } => Self::PipeWithSchema {
                input: Arc::clone(input),
                callback: callback.clone(),
            },
            #[cfg(feature = "pivot")]
            Self::Pivot { input, on, on_columns, index, values, agg, maintain_order, separator } => Self::Pivot {
                input: Arc::clone(input),
                on: on.clone(),
                on_columns: Arc::clone(on_columns),
                index: index.clone(),
                values: values.clone(),
                agg: agg.clone(),
                maintain_order: *maintain_order,
                separator: separator.clone(),
            },
            Self::Distinct { input, options } => Self::Distinct {
                input: Arc::clone(input),
                options: options.clone(),
            },
            Self::Sort { input, by_column, slice, sort_options } => Self::Sort {
                input: Arc::clone(input),
                by_column: by_column.clone(),
                slice: *slice,
                sort_options: sort_options.clone(),
            },
            Self::Slice { input, offset, len } => Self::Slice {
                input: Arc::clone(input),
                offset: *offset,
                len: *len,
            },
            Self::MapFunction { input, function } => Self::MapFunction {
                input: Arc::clone(input),
                function: function.clone(),
            },
            Self::Union { inputs, args } => Self::Union {
                inputs: inputs.clone(),
                args: args.clone(),
            },
            Self::HConcat { inputs, options } => Self::HConcat {
                inputs: inputs.clone(),
                options: options.clone(),
            },
            Self::ExtContext { input, contexts } => Self::ExtContext {
                input: Arc::clone(input),
                contexts: contexts.clone(),
            },
            Self::Sink { input, payload } => Self::Sink {
                input: Arc::clone(input),
                payload: payload.clone(),
            },
            Self::SinkMultiple { inputs } => Self::SinkMultiple {
                inputs: inputs.clone(),
            },
            #[cfg(feature = "merge_sorted")]
            Self::MergeSorted { input_left, input_right, key } => Self::MergeSorted {
                input_left: Arc::clone(input_left),
                input_right: Arc::clone(input_right),
                key: key.clone(),
            },
            Self::IR { dsl, version, node } => Self::IR {
                dsl: Arc::clone(dsl),
                version: *version,
                node: *node,
            },
        }
    }
}

impl Default for DslPlan {
    fn default() -> Self {
        let df = DataFrame::empty();
        let schema = df.schema().clone();
        DslPlan::DataFrameScan {
            df: Arc::new(df),
            schema,
        }
    }
}

/// Options for `DslPlan::serialize_versioned`.
// `Debug`/`PartialEq`/`Eq` are derived so the context can be logged and compared like other
// public option types; all fields are plain `Copy` data.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct PlanSerializationContext {
    /// Forwarded to `pl_serialize::USE_CLOUDPICKLE` before the plan is written.
    /// Defaults to `false`.
    pub use_cloudpickle: bool,
}

impl DslPlan {
    /// Converts a clone of this plan to IR and returns the IR plan's textual description.
    pub fn describe(&self) -> PolarsResult<String> {
        let ir_plan = self.clone().to_alp()?;
        Ok(ir_plan.describe())
    }

    /// Converts a clone of this plan to IR and returns its tree-formatted description.
    pub fn describe_tree_format(&self) -> PolarsResult<String> {
        let ir_plan = self.clone().to_alp()?;
        Ok(ir_plan.describe_tree_format())
    }

    /// Converts a clone of this plan to IR and returns a value whose `Display` renders that IR.
    pub fn display(&self) -> PolarsResult<impl fmt::Display> {
        // Owns the converted plan so the returned value does not borrow `self`.
        struct PlanDisplay(IRPlan);

        impl fmt::Display for PlanDisplay {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                fmt::Display::fmt(&self.0.as_ref().display(), f)
            }
        }

        let ir_plan = self.clone().to_alp()?;
        Ok(PlanDisplay(ir_plan))
    }

    /// Consumes this plan and converts it into an [`IRPlan`].
    ///
    /// Fresh plan and expression arenas are allocated, and conversion runs with
    /// `OptFlags::default()`.
    pub fn to_alp(self) -> PolarsResult<IRPlan> {
        let mut ir_arena = Arena::with_capacity(16);
        let mut expr_arena = Arena::with_capacity(16);
        let mut opt_flags = OptFlags::default();

        let root = to_alp(self, &mut expr_arena, &mut ir_arena, &mut opt_flags)?;
        Ok(IRPlan::new(root, ir_arena, expr_arena))
    }

    /// Writes this plan to `writer` in the versioned DSL format.
    ///
    /// Layout: `DSL_MAGIC_BYTES`, the major and minor version as little-endian `u16`s, the
    /// hex-encoded schema hash, and finally the serialized plan.
    ///
    /// # Errors
    /// Returns an error if writing the header fails or if plan serialization fails.
    #[cfg(feature = "serde")]
    pub fn serialize_versioned<W: Write>(
        &self,
        mut writer: W,
        ctx: PlanSerializationContext,
    ) -> PolarsResult<()> {
        // @GB:
        // This is absolutely horrendous, but serde does not allow state to be passed along with
        // the serialization, so there is no proper way to do this short of replacing serde.
        pl_serialize::USE_CLOUDPICKLE.set(ctx.use_cloudpickle);

        let (major, minor) = DSL_VERSION;
        writer.write_all(DSL_MAGIC_BYTES)?;
        writer.write_all(&major.to_le_bytes())?;
        writer.write_all(&minor.to_le_bytes())?;
        writer.write_all(DSL_SCHEMA_HASH.as_bytes())?;

        let serializable_plan = serializable_plan::SerializableDslPlan::from(self);
        pl_serialize::serialize_dsl(writer, &serializable_plan)
            .map_err(|e| polars_err!(ComputeError: "serialization failed\n\nerror: {e}"))
    }

    /// Reads a plan previously written by `serialize_versioned`.
    ///
    /// # Errors
    /// Fails if the header cannot be read, the magic bytes are missing, the major version
    /// differs, the minor version is newer than this build's, the schema hash is not valid hex,
    /// or the schema hash differs from this build's (unless the environment variable
    /// `POLARS_SKIP_DSL_HASH_VERIFICATION` is `1`), or if plan deserialization fails.
    #[cfg(feature = "serde")]
    pub fn deserialize_versioned<R: Read>(mut reader: R) -> PolarsResult<Self> {
        const MAGIC_LEN: usize = DSL_MAGIC_BYTES.len();
        const MAJOR: u16 = DSL_VERSION.0;
        const MINOR: u16 = DSL_VERSION.1;

        // Header: magic bytes, then the little-endian major and minor versions.
        let mut header = [0u8; MAGIC_LEN + 4];
        reader
            .read_exact(&mut header)
            .map_err(|e| polars_err!(ComputeError: "failed to read incoming DSL_VERSION: {e}"))?;

        let (magic, version) = header.split_at(MAGIC_LEN);
        if magic != DSL_MAGIC_BYTES {
            polars_bail!(ComputeError: "dsl magic bytes not found");
        }
        let major = u16::from_le_bytes([version[0], version[1]]);
        let minor = u16::from_le_bytes([version[2], version[3]]);

        if polars_core::config::verbose() {
            eprintln!(
                "incoming DSL_VERSION: {major}.{minor}, deserializer DSL_VERSION: {MAJOR}.{MINOR}"
            );
        }

        // Major versions must match exactly; older minor versions are accepted.
        let version_error = if major != MAJOR {
            Some("error: can't deserialize DSL with a different major version")
        } else if minor > MINOR {
            Some("error: can't deserialize DSL with a higher minor version")
        } else {
            None
        };
        if let Some(reason) = version_error {
            polars_bail!(ComputeError:
                "deserialization failed\n\ngiven DSL_VERSION: {major}.{minor} is not compatible with this Polars version which uses DSL_VERSION: {MAJOR}.{MINOR}\n{}",
                reason
            );
        }

        let mut hash_bytes = [0_u8; SCHEMA_HASH_LEN];
        reader.read_exact(&mut hash_bytes).map_err(
            |e| polars_err!(ComputeError: "failed to read incoming DSL_SCHEMA_HASH: {e}"),
        )?;
        let Some(incoming_hash) = SchemaHash::new(&hash_bytes) else {
            polars_bail!(ComputeError: "failed to read incoming DSL schema hash, not a valid hex string");
        };

        if polars_core::config::verbose() {
            eprintln!(
                "incoming DSL_SCHEMA_HASH: {incoming_hash}, deserializer DSL_SCHEMA_HASH: {DSL_SCHEMA_HASH}"
            );
        }

        // The env var is consulted first; when it is "1" the hash comparison is skipped.
        let skip_hash_check =
            std::env::var("POLARS_SKIP_DSL_HASH_VERIFICATION").as_deref() == Ok("1");
        if !skip_hash_check && incoming_hash != DSL_SCHEMA_HASH {
            polars_bail!(ComputeError:
                "deserialization failed\n\ngiven DSL_SCHEMA_HASH: {incoming_hash} is not compatible with this Polars version which uses DSL_SCHEMA_HASH: {DSL_SCHEMA_HASH}\n{}",
                "error: can't deserialize DSL with incompatible schema"
            );
        }

        let decoded: serializable_plan::SerializableDslPlan = pl_serialize::deserialize_dsl(reader)
            .map_err(|e| polars_err!(ComputeError: "deserialization failed\n\nerror: {e}"))?;
        (&decoded).try_into()
    }

    /// Returns the JSON schema of [`DslPlan`].
    ///
    /// Descriptions generated from doc comments are removed from every subschema, and the DSL
    /// schema hash is added as a top-level `hash` field.
    #[cfg(feature = "dsl-schema")]
    pub fn dsl_schema() -> schemars::Schema {
        use schemars::Schema;
        use schemars::generate::SchemaSettings;
        use schemars::transform::{Transform, transform_subschemas};

        /// Recursively strips the `description` keyword from a schema and its subschemas.
        #[derive(Clone, Copy, Debug)]
        struct StripDescriptions;

        impl Transform for StripDescriptions {
            fn transform(&mut self, schema: &mut Schema) {
                schema.remove("description");
                transform_subschemas(self, schema);
            }
        }

        let generator = SchemaSettings::default()
            .with_transform(StripDescriptions)
            .into_generator();
        let mut root = generator.into_root_schema_for::<DslPlan>();
        root.insert("hash".into(), DSL_SCHEMA_HASH.to_string().into());
        root
    }
}

/// Length of a hex-encoded SHA-256 digest: 32 raw bytes, two hex characters each.
const SCHEMA_HASH_LEN: usize = 32 * 2;

/// Hex-encoded DSL schema hash borrowed for `'h`; build it with `SchemaHash::new` to validate.
struct SchemaHash<'h>(&'h str);

impl SchemaHash<'static> {
    const fn from_hash_file() -> Self {
        // Generated by build.rs
        let bytes = include_bytes!(concat!(env!("OUT_DIR"), "/dsl-schema.sha256"));
        Self::new(bytes).expect("not a valid hex string")
    }
}

impl<'a> SchemaHash<'a> {
    /// Wraps `bytes` as a schema hash after checking that every byte is an ASCII hex digit
    /// (`0-9`, `a-f`, `A-F`); returns `None` otherwise.
    ///
    /// Usable in `const` contexts.
    const fn new(bytes: &'a [u8; SCHEMA_HASH_LEN]) -> Option<Self> {
        // Iterators are not available in `const fn`, hence the manual index loop.
        let mut idx = 0;
        while idx < SCHEMA_HASH_LEN {
            if !bytes[idx].is_ascii_hexdigit() {
                return None;
            }
            idx += 1;
        }
        // Every byte is ASCII at this point, so UTF-8 decoding cannot fail.
        if let Ok(hash) = str::from_utf8(bytes) {
            Some(Self(hash))
        } else {
            unreachable!()
        }
    }

    /// Returns the hash as its raw hex bytes.
    ///
    /// Panics if the wrapped string is not `SCHEMA_HASH_LEN` bytes long, which cannot happen
    /// for values built through `SchemaHash::new`.
    fn as_bytes(&self) -> &'a [u8; SCHEMA_HASH_LEN] {
        let raw: &'a [u8] = self.0.as_bytes();
        <&[u8; SCHEMA_HASH_LEN]>::try_from(raw).unwrap()
    }
}

impl PartialEq for SchemaHash<'_> {
    /// Hex digests compare ASCII case-insensitively, so `"AB"` equals `"ab"`.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.0, other.0);
        lhs.eq_ignore_ascii_case(rhs)
    }
}

impl fmt::Display for SchemaHash<'_> {
    /// Writes the hash exactly as stored (letter case preserved, formatter flags ignored).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(self.0)
    }
}
